# -*- coding: utf-8 -*-

from collections import OrderedDict

import shortuuid
import sqlalchemy as sa
from sqlalchemy import Column, Integer, ForeignKey, Table
from sqlalchemy import event
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import relationship, backref
from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql.expression import FunctionElement

from bopress import settings
from bopress.lib.sqlaljson import JsonSerializableBase
from bopress.log import Logger

__author__ = 'yezang'


@declared_attr
def __tablename__(*args):
    """
    Derive an entity's table name from its class name.

    The first positional argument is the entity class; the table name is the
    lower-cased class name prefixed with ``settings.TABLE_NAME_PREFIX``.

    :return: the prefixed table name
    """
    entity_cls = args[0]
    return "%s%s" % (settings.TABLE_NAME_PREFIX, entity_cls.__name__.lower())


# Declarative base class for the entities in this module's metadata;
# JsonSerializableBase is passed as the ``cls`` argument of declarative_base.
Entity = declarative_base(cls=(JsonSerializableBase,))
# Attach the declared_attr defined above so that table names are computed from
# the entity class name plus settings.TABLE_NAME_PREFIX.
Entity.__tablename__ = __tablename__


def pk():
    """
    Generate a UUID-based value intended for use as a table primary key.

    :return: the value produced by ``shortuuid.uuid()``
    """
    new_key = shortuuid.uuid()
    return new_key


class SessionFactory(object):
    """
    Static helpers built around one lazily created database engine.

    The engine comes from ``settings.DBEngine()`` and is cached on the class
    after the first call to :meth:`instance`.
    """
    # Cached engine; stays None until instance() is first called.
    _engine = None

    @staticmethod
    def instance():
        """
        Return the shared engine, creating it on first use.

        :return: the object returned by ``settings.DBEngine()``
        """
        if SessionFactory._engine is None:
            SessionFactory._engine = settings.DBEngine()
        return SessionFactory._engine

    @staticmethod
    def _on_bulk_delete(delete_context):
        """Listener for ``after_bulk_delete``; currently does nothing."""
        pass

    @staticmethod
    def _on_bulk_update(update_context):
        """Listener for ``after_bulk_update``; currently does nothing."""
        pass

    @staticmethod
    def session(use_evt=False):
        """
        Create a new session bound to the shared engine.

        :param use_evt: when true, register the bulk delete / bulk update
            listeners of this class on the new session
        :return: the new session; the caller is responsible for closing it
        """
        s = sessionmaker(bind=SessionFactory.instance())()
        if use_evt:
            event.listen(s, "after_bulk_delete", SessionFactory._on_bulk_delete)
            event.listen(s, "after_bulk_update", SessionFactory._on_bulk_update)
        return s

    @staticmethod
    def connect():
        """
        Open a connection on the shared engine.

        :return: the result of ``engine.connect()``
        """
        return SessionFactory.instance().connect()

    @staticmethod
    def transaction(func, auto_commit=True):
        """
        Run a callback inside a raw SQL transaction.

        Any exception raised by the callback (or by the commit) triggers a
        rollback, is logged through ``Logger.exception`` and is NOT re-raised.

        :param func: callback, func(conn, transaction)
        :param auto_commit: commit after ``func`` returns; when false nothing
            is committed here, so the callback presumably has to commit
            through the transaction object it receives
        """
        with SessionFactory.connect() as c:
            trans = c.begin()
            try:
                func(c, trans)
                if auto_commit:
                    trans.commit()
            except Exception as e:
                trans.rollback()
                Logger.exception(e)

    @staticmethod
    def create_tables():
        """Create every table registered on ``Entity.metadata``."""
        Entity.metadata.create_all(SessionFactory.instance())

    @staticmethod
    def has_data(model_cls):
        """
        Tell whether the table of ``model_cls`` holds at least one row.

        Handy when deciding whether test data still has to be inserted.

        :param model_cls: mapped entity class to count
        :return: True if the row count is greater than zero, else False
        """
        s = SessionFactory.session()
        try:
            return s.query(model_cls).count() > 0
        finally:
            # The session is private to this call; close it so it is not
            # leaked, even when the query raises.
            s.close()


def get_primary_keys(entity_cls):
    """
    Collect the primary-key columns of a mapped entity class.

    :param entity_cls: entity class to inspect
    :return: ``OrderedDict`` mapping column key to column, limited to columns
        whose ``primary_key`` flag is set, in the order the inspection
        reports them
    """
    primary = OrderedDict()
    for name, col in sa.inspect(entity_cls).columns.items():
        if col.primary_key:
            primary[name] = col
    return primary


def get_ref_column(entity_cls):
    """
    Return the single primary-key column of an entity class.

    :param entity_cls: entity class to inspect
    :return: the primary-key column when the entity has exactly one,
        otherwise ``None`` (no primary key, or a composite one)
    """
    primary = get_primary_keys(entity_cls)
    if len(primary) != 1:
        return None
    return next(iter(primary.values()))


def many_to_one(target_entity_cls, tow_way=True, ondelete="CASCADE", onupdate="NO ACTION", lazy='dynamic'):
    """
    Class decorator declaring a many-to-one relation to ``target_entity_cls``.

    See https://wtforms-alchemy.readthedocs.io/en/latest/relationships.html#one-to-many-relations
    Many-to-one and one-to-many share this form.

    The decorated ("many") class gains a foreign-key column named after the
    target's primary-key column, plus an attribute named after the target
    class in lower case. With ``tow_way`` the target ("one") class also gains
    a ``<decorated class name in lower case>_set`` backref.

    If the target does not have exactly one primary-key column the class is
    returned untouched.

    :param target_entity_cls: parent entity class
    :param tow_way: also create the reverse one-to-many side as a backref
    :param ondelete: ``ON DELETE`` action passed to the foreign key
    :param onupdate: ``ON UPDATE`` action passed to the foreign key
    :param lazy: loading strategy given to the backref; with ``'dynamic'`` the
        backref yields a query object that can be filtered further
    :return: the class decorator
    """

    def ref_table(cls):
        target_pk = get_ref_column(target_entity_cls)
        if target_pk is not None:
            fk = ForeignKey(target_pk, ondelete=ondelete, onupdate=onupdate)
            setattr(cls, target_pk.key, Column(target_pk.type, fk, nullable=True))
            if tow_way:
                # parent on this side, <children>_set on the target side
                reverse_name = "{0}_set".format(cls.__name__.lower())
                relation = relationship(target_entity_cls,
                                        backref=backref(reverse_name, lazy=lazy))
            else:
                # parent only
                relation = relationship(target_entity_cls)
            setattr(cls, target_entity_cls.__name__.lower(), relation)
        return cls

    return ref_table


def one_to_one(target_entity_cls, ondelete="CASCADE", onupdate="NO ACTION"):
    """
    Class decorator declaring a one-to-one relation to ``target_entity_cls``.

    The decorated class gains a foreign-key column named after the target's
    primary-key column and an attribute named after the target class in lower
    case; the target gains a scalar (``uselist=False``) backref named after
    the decorated class in lower case.

    If the target does not have exactly one primary-key column the class is
    returned untouched.

    :param target_entity_cls: parent entity class
    :param ondelete: ``ON DELETE`` action passed to the foreign key
        (``CASCADE``, ``RESTRICT``, ``NO ACTION``, ``SET NULL`` ...)
    :param onupdate: ``ON UPDATE`` action passed to the foreign key
        (``CASCADE``, ``RESTRICT``, ``NO ACTION``, ``SET NULL`` ...)
    :return: the class decorator
    """

    def ref_table(cls):
        target_pk = get_ref_column(target_entity_cls)
        if target_pk is not None:
            attr_name = target_entity_cls.__name__.lower()
            fk = ForeignKey(target_pk, ondelete=ondelete, onupdate=onupdate)
            setattr(cls, target_pk.key, Column(target_pk.type, fk, nullable=True))
            # parent on this side, single child on the target side
            reverse = backref(cls.__name__.lower(), uselist=False)
            setattr(cls, attr_name, relationship(target_entity_cls, backref=reverse))
        return cls

    return ref_table


def many_to_many(target_entity_cls, ondelete="CASCADE", onupdate="NO ACTION", lazy='dynamic'):
    """
    Class decorator declaring a many-to-many relation to ``target_entity_cls``.

    It can decorate either of the two related entities. An association table
    named ``<prefix><decorated>_<target>_relationships`` is registered on
    ``Entity.metadata`` with one foreign-key column per side. The decorated
    class gains ``<target name in lower case>_set`` and the target gains
    ``<decorated name in lower case>_set``.

    If either class does not have exactly one primary-key column the class is
    returned untouched.

    :param target_entity_cls: entity class to relate to
    :param ondelete: ``ON DELETE`` action passed to both foreign keys
        (``CASCADE``, ``RESTRICT``, ``NO ACTION``, ``SET NULL`` ...)
    :param onupdate: ``ON UPDATE`` action passed to both foreign keys
        (``CASCADE``, ``RESTRICT``, ``NO ACTION``, ``SET NULL`` ...)
    :param lazy: loading strategy applied to BOTH sides of the relation; with
        the default ``'dynamic'`` each side yields a query object that can be
        filtered further
    :return: the class decorator
    """

    def ref_table(cls):
        cls_pk = get_ref_column(cls)
        if cls_pk is None:
            return cls
        target_cls_pk = get_ref_column(target_entity_cls)
        if target_cls_pk is None:
            return cls
        target_name = target_entity_cls.__name__.lower()
        self_name = cls.__name__.lower()
        association_table = Table(
            '{0}{1}_{2}_relationships'.format(settings.TABLE_NAME_PREFIX, self_name, target_name),
            Entity.metadata,
            Column(target_cls_pk.key, target_cls_pk.type,
                   ForeignKey(target_cls_pk, ondelete=ondelete, onupdate=onupdate)),
            Column(cls_pk.key, cls_pk.type, ForeignKey(cls_pk, ondelete=ondelete, onupdate=onupdate))
        )
        # parent_set, children_set
        # The forward side previously hard-coded lazy='dynamic', so the
        # ``lazy`` argument only reached the backref; honour it on both sides.
        setattr(cls, "{0}_set".format(target_name),
                relationship(target_entity_cls, secondary=association_table,
                             backref=backref("{0}_set".format(self_name), lazy=lazy), lazy=lazy))
        return cls

    return ref_table


# Database-specific SQL functions

class MinuteDiff(FunctionElement):
    """
    SQL function computing the difference between two dates in minutes.

    The result is typed as an integer. The actual SQL is produced by the
    per-dialect compilers registered for this class.
    """
    type = Integer()
    name = "minute_diff"
    # The construct adds no state beyond what FunctionElement already puts in
    # its cache key, so opt in to statement caching (SQLAlchemy 1.4+ otherwise
    # warns and skips the cache); older versions ignore this attribute.
    inherit_cache = True


@compiles(MinuteDiff, 'mssql')
def _mssql_minute_diff(element, compiler, **kw):
    """
    Render ``MinuteDiff`` for SQL Server as ``DATEDIFF(MINUTE, a, b)``.

    :param element: the ``MinuteDiff`` construct; its first two clauses are
        used as ``a`` and ``b``
    :param compiler: the SQL compiler doing the rendering
    :param kw: compile options, forwarded to the operands so that flags such
        as ``literal_binds`` are not lost
    :return: the SQL fragment
    """
    operands = element.clauses.clauses
    return "DATEDIFF(MINUTE, %s, %s)" % (compiler.process(operands[0], **kw),
                                         compiler.process(operands[1], **kw))


@compiles(MinuteDiff, 'mysql')
def _mysql_minute_diff(element, compiler, **kw):
    """
    Render ``MinuteDiff`` for MySQL as ``TIMESTAMPDIFF(MINUTE, a, b)``.

    :param element: the ``MinuteDiff`` construct; its first two clauses are
        used as ``a`` and ``b``
    :param compiler: the SQL compiler doing the rendering
    :param kw: compile options, forwarded to the operands so that flags such
        as ``literal_binds`` are not lost
    :return: the SQL fragment
    """
    operands = element.clauses.clauses
    return "TIMESTAMPDIFF(MINUTE, %s, %s)" % (compiler.process(operands[0], **kw),
                                              compiler.process(operands[1], **kw))


@compiles(MinuteDiff, 'sqlite')
def _sqlite_minute_diff(element, compiler, **kw):
    """
    Render ``MinuteDiff`` for SQLite using ``julianday``.

    The result is ``b - a`` in whole minutes (truncated toward zero), where
    ``a`` and ``b`` are the first and second clause. This matches the operand
    order of the ``DATEDIFF`` / ``TIMESTAMPDIFF`` renderings used for the
    other dialects and the integer type declared on ``MinuteDiff``.

    NOTE(review): this function used to render ``a - b`` as a fractional
    value; callers that relied on that sign on SQLite need to be checked.

    :param element: the ``MinuteDiff`` construct
    :param compiler: the SQL compiler doing the rendering
    :param kw: compile options, forwarded to the operands so that flags such
        as ``literal_binds`` are not lost
    :return: the SQL fragment
    """
    operands = element.clauses.clauses
    # Process the operands in the order they appear in the SQL text, as a
    # precaution so positional ("?") bind parameters stay aligned.
    later = compiler.process(operands[1], **kw)
    earlier = compiler.process(operands[0], **kw)
    # Round to whole seconds first to absorb floating-point error in the
    # julian-day subtraction, then truncate to whole minutes.
    return "CAST(ROUND((julianday(%s) - julianday(%s)) * 86400) / 60 AS INTEGER)" % (later, earlier)
